package main

import (
	"easy-proxy/balancer"
	"easy-proxy/conf"
	"easy-proxy/detect"
	"easy-proxy/node"
	"flag"
	"fmt"
	"io"
	"log"
	"net"
	"os"
	"strconv"
	"strings"
	"sync"
	"time"
)

var (
	mu sync.Mutex

	// 日志级别定义
	Info  *log.Logger
	Error *log.Logger

	// 负载均衡列表，initList来源于配置文件，nodeList动态变化
	initList []*node.Node
	nodeList []*node.Node

	cnf = flag.String("cnf", "", "--cnf=xxx.conf") // 配置文件路径

	port        string
	strategy    string // 负载均衡策略
	ticktime    int64  // 监控心跳间隔
	dialtimeout int64  // 连通性检测超时时长
	cluster     string // 后端集群类型
	detectgtdb  string
	detectmgr   string
	flagState   string // 是否检测节点可写状态
	user        string
	pass        string
)

var listChange chan int

func init() {
	flag.Parse()
	if *cnf == "" {
		fmt.Println("please input cnf")
		os.Exit(0)
	}

	// 读取配置文件
	cf := new(conf.Config)
	cf.InitConfig(*cnf)
	port = cf.Read("default", "port")
	strategy = cf.Read("default", "strategy")
	user = cf.Read("default", "user")
	pass = cf.Read("default", "pass")

	// 心跳时间间隔
	i, err := strconv.ParseInt(cf.Read("default", "ticktime"), 10, 64)
	if err != nil {
		Error.Println("Not support config for ticktime: ", err)
	}
	ticktime = i

	// 连通检测超时时间
	j, err := strconv.ParseInt(cf.Read("default", "dialtimeout"), 10, 64)
	if err != nil {
		Error.Println("Not support config for dialtimeout: ", err)
	}
	dialtimeout = j

	flagState = cf.Read("default", "flagState")
	cluster = cf.Read("default", "cluster")
	detectgtdb = cf.Read("default", "detectgtdb")
	detectmgr = cf.Read("default", "detectmgr")

	// 初始化默认配置列表
	initList = cf.ReadNode()

	// 初始化动态配置列表
	nodeList = cf.ReadNode()

	// 日志初始化
	file, err := os.OpenFile("log/easy.logs", os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
	if err != nil {
		log.Fatalln("Failed to open log file: ", err)
	}
	Info = log.New(io.MultiWriter(file, os.Stderr), "Info: ", log.Ldate|log.Ltime)
	Error = log.New(io.MultiWriter(file, os.Stderr), "Error: ", log.Ldate|log.Ltime|log.Lshortfile)
}

func main() {
	Info.Println("The current load balance strategy is: ", strategy)
	// 启动监听，获取客户端请求
	l, err := net.Listen("tcp", fmt.Sprintf("%s:%s", "0.0.0.0", port))
	Info.Println("Listen port: ", port)
	if err != nil {
		Error.Println("Listen failed, please check port:", port)
		os.Exit(0)
	}

	// 监控节点状态，更新动态负载均衡列表
	listChange = make(chan int)

	go HeartBeat()

	if strategy == "random" { // 随机负载均衡策略处理
		random := &balancer.Random{}
		random.Nodes = nodeList
		for {
			sConn, err := l.Accept()
			if err != nil {
				Error.Println("Connection accept failed! ", err)
				continue
			}

			// 监控节点协程通知列表发生变化
			go func() {
				for {
					val := <-listChange
					if val == 1 {
						mu.Lock()
						random.Nodes = nodeList
						mu.Unlock()
					}
				}
			}()
			randomNode := random.Next()
			addr := randomNode.Ip + ":" + strconv.Itoa(randomNode.Port)

			go ForwardTCPRequest(addr, sConn)
		}
	} else if strategy == "rr" {
		// 轮询负载均衡策略处理
		rr := &balancer.RoundRobin{}
		rr.Nodes = nodeList
		for {
			sConn, err := l.Accept()
			if err != nil {
				Error.Println("Connection accept failed! ", err)
				continue
			}

			// 更新动态负载均衡列表
			go func() {
				for {
					val := <-listChange
					if val == 1 {
						mu.Lock()
						rr.Nodes = nodeList
						mu.Unlock()
					}
				}
			}()
			rrNode := rr.Next()
			addr := rrNode.Ip + ":" + strconv.Itoa(rrNode.Port)

			go ForwardTCPRequest(addr, sConn)
		}
	} else if strategy == "wrr" {
		// 加权轮询负载均衡策略处理
		wrr := &balancer.WeightRoundRobin{}
		wrr.Nodes = nodeList
		for {
			sConn, err := l.Accept()
			if err != nil {
				Error.Println("Connection accept failed! ", err)
				continue
			}

			// 更新动态负载均衡列表
			go func() {
				for {
					val := <-listChange
					if val == 1 {
						mu.Lock()
						wrr.Nodes = nodeList
						mu.Unlock()
					}
				}
			}()
			wrrNode := wrr.Next()
			addr := wrrNode.Ip + ":" + strconv.Itoa(wrrNode.Port)

			go ForwardTCPRequest(addr, sConn)
		}
	} else {
		Error.Println("Not support load balance strategy")
		os.Exit(0)
	}
}

/**
转发客户端/应用端请求
*/
func ForwardTCPRequest(addr string, sConn net.Conn) {
	defer sConn.Close()
	// 转发请求
	dTcpAddr, _ := net.ResolveTCPAddr("tcp4", addr)
	dConn, err := net.DialTCP("tcp", nil, dTcpAddr)
	if err != nil {
		Error.Println("Connection failed! ", err)
		_, err = sConn.Write([]byte("can't connect " + addr))
	} else {
		exitCH := make(chan bool, 1)

		// 把客户端的的请求转发给后端
		go func(s net.Conn, d *net.TCPConn, ex chan bool) {
			_, err := io.Copy(sConn, dConn)
			if err != nil {
				Error.Println("Send data failure: ", err)
			}
			exitCH <- true
		}(sConn, dConn, exitCH)

		// 把响应的数据返回给客户端
		go func(s net.Conn, d *net.TCPConn, ex chan bool) {
			_, err := io.Copy(dConn, sConn)
			if err != nil {
				Error.Println("Receive data failure: ", err)
			}
			exitCH <- true
		}(sConn, dConn, exitCH)

		// channel阻塞，读取连接关闭状态
		<-exitCH
        
		// 收到连接终止信息后，关闭连接
		_ = dConn.Close()
	}
}

/**
节点监控心跳，由参数ticktime控制心跳间隔时长
*/
func HeartBeat() {
	for {
		if len(nodeList) == 0 {
			Error.Println("There is no available node in load balance list")
			os.Exit(0)
		}
		for i := 0; i < len(initList); i++ {
			addr := net.JoinHostPort(initList[i].Ip, strconv.Itoa(initList[i].Port))
			// 默认检查端口连通性进行探活
			_, err := net.DialTimeout("tcp", addr, time.Duration(dialtimeout)*time.Millisecond)

			/**
			连通性检测后，
			通过，则检查节点是否可写
			不通过，则再次确认
			*/
			if err != nil {
				go CheckDialFailed(initList[i])
			} else {
				go CheckDialSucc(initList[i])
			}
		}
		time.Sleep(time.Duration(ticktime) * time.Millisecond)
	}
}

/**
检查节点失败操作
1、通知主线程更新负载均衡队列
2、通知主线程关闭与故障节点的连接
*/
func CheckDialFailed(n *node.Node) {
	// 连接失败再次重试
	time.Sleep(time.Duration(dialtimeout) * time.Millisecond)
	addr := net.JoinHostPort(n.Ip, strconv.Itoa(n.Port))

	// 再次连接确认
	_, err := net.DialTimeout("tcp", addr, time.Duration(dialtimeout)*time.Millisecond)
	if err != nil {
		Error.Println(addr, " was detected and could not connect :", err)
		DelNode(n)
	}
}

/**
检查节点成功操作
1、通知主线程更新负载均衡队列
*/
func CheckDialSucc(n *node.Node) {
	addr := net.JoinHostPort(n.Ip, strconv.Itoa(n.Port))
	addFlag := false
	if flagState == "1" {
		var sql string
		if cluster == "greatdb" {
			sql = detectgtdb
		} else if cluster == "mgr" {
			sql = detectmgr
		}

		/**
		实例可写验证
		可写验证失败，从负载均衡列表删除节点
		*/
		ok, err := detect.State(strings.Trim(sql, "\""), user, pass, strconv.Itoa(n.Port), n.Ip, n.Hostname, cluster)
		if !ok {
			Error.Println("instance error reporting, please check!", err)
			DelNode(n)
			addFlag = false
		} else {
			addFlag = true
		}
	} else {
		// 再次确认连通性
		_, err := net.DialTimeout("tcp", addr, time.Duration(dialtimeout)*time.Millisecond)
		if err == nil {
			addFlag = true
		} else {
			Error.Println(addr, " was detected and could not connect :", err)
		}
	}

	// 心跳监控检查到故障节点已经恢复，加入到负载均衡队列
	if addFlag {
		exists := false
		for _, value := range nodeList {
			if value.Ip == n.Ip && value.Port == n.Port {
				exists = true
				break
			}
		}
		if !exists {
			mu.Lock()
			nodeList = append(nodeList, n)
			listChange <- 1
			mu.Unlock()
			Info.Println("The destination address is added to the load balance list :", addr)
		}
	}
}

/**
心跳监控检查到故障节点，从负载均衡队列清除
*/
func DelNode(n *node.Node) {
	for i := 0; i < len(nodeList); i++ {
		if nodeList[i].Ip == n.Ip && nodeList[i].Port == n.Port {
			mu.Lock()
			nodeList = append(nodeList[:i], nodeList[i+1:]...)
			listChange <- 1
			mu.Unlock()
			Error.Println("The destination address is removed from the load balance list :", net.JoinHostPort(n.Ip, strconv.Itoa(n.Port)))
		}
	}
}
